import torch
from torch.utils.data import Dataset

class PowerImpedanceDataset(Dataset):
    """
    自定义数据集类，用于加载电压和功率阻抗数据。
    """

    def __init__(self, voltage_data, power_impedance_data):
        """
        初始化方法
        :param voltage_data: 电压数据，形状为 (num_samples, 2, 9)
        :param power_impedance_data: 功率和阻抗数据，形状为 (num_samples, 2, 9)
        """
        self.voltage_data = voltage_data
        self.power_impedance_data = power_impedance_data

    def __len__(self):
        """
        返回数据集的大小（样本数量）
        """
        return len(self.voltage_data)

    def __getitem__(self, idx):
        """
        获取索引 idx 对应的数据样本
        :param idx: 数据索引
        :return: 返回电压数据和功率阻抗数据
        """
        voltage = self.voltage_data[idx]
        power_impedance = self.power_impedance_data[idx]
        return voltage, power_impedance


class VoltageGraphDataset(Dataset):
    def __init__(self, graph_list, voltage_tensor, transform=None):
        """
        初始化数据集。
        
        参数:
        - graph_list (list of Data): 包含图数据的列表，每个元素是一个 PyG Data 对象。
        - voltage_tensor (torch.Tensor): 电压数据张量，形状为 [72, 17]。
        - transform (callable, optional): 可选的转换函数，将应用于每个图数据。
        """
        assert len(graph_list) == voltage_tensor.size(0), "图列表和电压张量的长度不匹配。"
        self.graph_list = graph_list
        self.voltage_tensor = voltage_tensor
        self.transform = transform

    def __len__(self):
        """
        返回数据集的大小。
        """
        return len(self.graph_list)

    def __getitem__(self, idx):
        """
        获取数据集中的一个样本。
        
        参数:
        - idx (int): 索引。
        
        返回:
        - graph (Data): 图数据。
        - voltage (torch.Tensor): 对应的电压数据，形状为 [17]。
        """
        graph = self.graph_list[idx]
        voltage = self.voltage_tensor[idx]
        
        if self.transform:
            graph = self.transform(graph)
        
        return graph, voltage

    def print_shapes(self):
        """
        打印图数据列表和电压张量的形状。
        """
        print(f"Graph list length: {len(self.graph_list)}")
        print(f"Voltage tensor shape: {self.voltage_tensor.shape}")
